# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This file contains code related to network configuration.

It also includes encryption, network isolation, and VPC configurations.
"""
from __future__ import absolute_import

from typing import Union, Optional, List

from sagemaker.core.helper.pipeline_variable import PipelineVariable


class NetworkConfig(object):
    """Holds network settings and converts them into a request dictionary.

    The stored settings are network isolation, inter-container traffic
    encryption, and the VPC security groups and subnets. ``_to_request_dict``
    turns whatever was provided into the dict form used in a request.
    """

    def __init__(
        self,
        enable_network_isolation: Optional[Union[bool, PipelineVariable]] = None,
        security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
        subnets: Optional[List[Union[str, PipelineVariable]]] = None,
        encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None,
    ):
        """Initialize a ``NetworkConfig`` instance.

        The arguments are stored unchanged on the instance; no validation is
        performed here.

        Args:
            enable_network_isolation (bool or PipelineVariable): Boolean that determines
                whether to enable network isolation. Default value is None, which
                ``_to_request_dict`` emits as ``False``.
            security_group_ids (list[str] or list[PipelineVariable]): A list of strings
                representing security group IDs. Default value is None.
            subnets (list[str] or list[PipelineVariable]): A list of strings representing
                subnets. Default value is None.
            encrypt_inter_container_traffic (bool or PipelineVariable): Boolean that determines
                whether to encrypt inter-container traffic. Default value is None, in
                which case the setting is left out of the request dictionary.
        """
        self.enable_network_isolation = enable_network_isolation
        self.security_group_ids = security_group_ids
        self.subnets = subnets
        self.encrypt_inter_container_traffic = encrypt_inter_container_traffic

    def _to_request_dict(self) -> dict:
        """Generates a request dictionary using the parameters provided to the class.

        Returns:
            dict: Always contains ``EnableNetworkIsolation``. Contains
            ``EnableInterContainerTrafficEncryption`` only when that setting is not
            None. Contains ``VpcConfig`` only when at least one of
            ``security_group_ids`` / ``subnets`` is not None, and that nested dict
            holds only the key(s) that were actually provided.
        """
        # EnableNetworkIsolation is always sent; an unset value is reported as False.
        isolation = (
            False if self.enable_network_isolation is None else self.enable_network_isolation
        )
        request = {"EnableNetworkIsolation": isolation}

        if self.encrypt_inter_container_traffic is not None:
            request["EnableInterContainerTrafficEncryption"] = (
                self.encrypt_inter_container_traffic
            )

        # Collect only the VPC settings that were provided. An empty list still
        # counts as "provided" because the check is against None, not truthiness.
        vpc_config = {}
        if self.security_group_ids is not None:
            vpc_config["SecurityGroupIds"] = self.security_group_ids
        if self.subnets is not None:
            vpc_config["Subnets"] = self.subnets

        # vpc_config has a key exactly when one of the two settings was provided.
        if vpc_config:
            request["VpcConfig"] = vpc_config

        return request


# VPC Utilities (merged from vpc_utils.py)

# Key names of the two entries in a VpcConfig dict. ``to_dict`` builds a dict
# with these keys; ``from_dict`` and ``sanitize`` read them back.
SUBNETS_KEY = "Subnets"
SECURITY_GROUP_IDS_KEY = "SecurityGroupIds"
# NOTE(review): not referenced by the helpers in this file; presumably the key
# under which a VpcConfig dict is nested in a larger request/response dict --
# confirm against callers.
VPC_CONFIG_KEY = "VpcConfig"

# A global constant value for methods which can optionally override VpcConfig
# NOTE(review): presumably used as a sentinel default so that ``None`` can still
# be passed explicitly as a VpcConfig value -- confirm against callers.
VPC_CONFIG_DEFAULT = "VPC_CONFIG_DEFAULT"


def to_dict(subnets, security_group_ids):
    """Builds a VpcConfig dict from subnets and security group IDs.

    The result uses the keys 'Subnets' and 'SecurityGroupIds', which is the
    dict format expected by SageMaker CreateTrainingJob and CreateModel APIs.
    See https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html

    Args:
        subnets (list): list of subnet IDs to use in VpcConfig
        security_group_ids (list): list of security group IDs to use in
            VpcConfig

    Returns:
        A VpcConfig dict containing keys 'Subnets' and 'SecurityGroupIds'.
        If either or both parameters are None, returns None. The values are
        stored as given (they are not copied or validated).
    """
    # A VpcConfig is only produced when both halves are present.
    if subnets is not None and security_group_ids is not None:
        return {SUBNETS_KEY: subnets, SECURITY_GROUP_IDS_KEY: security_group_ids}
    return None


def from_dict(vpc_config, do_sanitize=False):
    """Splits a VpcConfig dict into its subnets and security group IDs.

    Args:
        vpc_config (dict): a VpcConfig dict containing 'Subnets' and
            'SecurityGroupIds'
        do_sanitize (bool): whether to pass the VpcConfig dict through
            ``sanitize`` before extracting values

    Returns:
        Tuple as (subnets, security_group_ids), taken from the 'Subnets' and
        'SecurityGroupIds' entries. If vpc_config parameter is None, returns
        (None, None).

    Raises:
        * ValueError if sanitize enabled and vpc_config is invalid

        * KeyError if sanitize disabled and vpc_config is missing key(s)
    """
    # sanitize() passes None through unchanged, so the None check below covers
    # both the sanitized and the unsanitized path.
    config = sanitize(vpc_config) if do_sanitize else vpc_config
    if config is not None:
        return config[SUBNETS_KEY], config[SECURITY_GROUP_IDS_KEY]
    return None, None


def sanitize(vpc_config):
    """Validates a VpcConfig dict and returns a copy holding only the expected keys.

    Both 'Subnets' and 'SecurityGroupIds' must be present and map to a
    non-empty list. Any other keys in the input are left out of the returned
    dict. The input dict itself is not modified.

    Args:
        vpc_config (dict): a VpcConfig dict containing 'Subnets' and
            'SecurityGroupIds'

    Returns:
        A new VpcConfig dict containing only 'Subnets' and 'SecurityGroupIds'
        from the vpc_config parameter (the list values are the same objects,
        not copies). If vpc_config parameter is None, returns None.

    Raises:
        ValueError if any expectations are violated:
            * vpc_config must be a non-empty dict
            * vpc_config must have key `Subnets` and the value must be a non-empty list
            * vpc_config must have key `SecurityGroupIds` and the value must be a non-empty list
    """
    if vpc_config is None:
        return None
    if not isinstance(vpc_config, dict):
        raise ValueError("vpc_config is not a dict: {}".format(vpc_config))
    if not vpc_config:
        raise ValueError("vpc_config is empty")

    # Both entries go through the same three checks. Subnets is checked
    # completely before SecurityGroupIds, so the first problem found in that
    # order is the one reported.
    checked = []
    for key in (SUBNETS_KEY, SECURITY_GROUP_IDS_KEY):
        value = vpc_config.get(key)
        # A key that is present but mapped to None is reported as missing too.
        if value is None:
            raise ValueError("vpc_config is missing key: {}".format(key))
        if not isinstance(value, list):
            raise ValueError("vpc_config value for {} is not a list: {}".format(key, value))
        if not value:
            raise ValueError("vpc_config value for {} is empty".format(key))
        checked.append(value)

    subnets, security_group_ids = checked
    return to_dict(subnets, security_group_ids)
